import torch
import numpy as np

from core.rl.components.agent import BaseAgent
from core.utils.general_utils import ParamDict, map_dict, AttrDict
from core.utils.pytorch_utils import ten2ar, avg_grad_norm


class DQNAgent(BaseAgent):
    """Implements a flat (non-hierarchical) DQN agent for discrete action spaces.

    The Q-network (``critic``) outputs one Q-value per action. Actions are chosen greedily
    (argmax over Q-values) with epsilon-greedy exploration during training. An optional target
    network is used for computing the bootstrap target.
    """
    def __init__(self, config):
        super().__init__(config)
        self._hp = self._default_hparams().overwrite(config)
        self._eps = self._hp.init_epsilon   # current exploration rate, decayed in update()

        # set up critic and optimizer
        self.critic = self._hp.critic(self._hp.critic_params)
        self.critic_opt = self._get_optimizer(self._hp.optimizer, self.critic, self._hp.critic_lr)

        # set up target network (separate instance of the same class, initialized with the critic's weights)
        if self._hp.use_target_network:
            self.critic_target = self._hp.critic(self._hp.critic_params)
            self._copy_to_target_network(self.critic_target, self.critic)

    def _default_hparams(self):
        """Returns the base agent's default hyperparameters extended with DQN-specific ones."""
        default_dict = ParamDict({
            'critic': None,     # Q-network class
            'critic_params': None,  # parameters for the Q-network class
            'critic_lr': 3e-4,  # learning rate for Q-network update
            'init_epsilon': 0.05,     # for epsilon-greedy exploration
            'epsilon_decay': 0.0,   # per-update-iteration reduction of epsilon
            'min_epsilon': 0.05,     # minimal epsilon value
            'use_target_network': True,     # if True, uses target network for computing target value
        })
        return super()._default_hparams().overwrite(default_dict)

    def _act(self, obs):
        """Predicts Q-values for all actions in the current state and selects an action.

        In training mode, a uniformly random action is taken with probability epsilon; otherwise
        the action is the argmax of the predicted Q-values.

        Args:
            obs: a single, unbatched observation (must be 1-dimensional).

        Returns:
            The critic's output dict with an added ``action`` entry; the leading batch dimension
            is stripped from every tensor entry.
        """
        assert len(obs.shape) == 1      # TODO implement batched act function
        critic_outputs = self.critic(obs[None])     # add a batch dimension of size 1

        # epsilon-greedy exploration (only while training)
        if self._is_train and np.random.uniform() < self._eps:
            critic_outputs.action = self._sample_rand_action(critic_outputs.q)
        else:
            critic_outputs.action = torch.argmax(critic_outputs.q, dim=-1)

        # remove the batch dimension added above
        return map_dict(lambda x: x[0] if isinstance(x, torch.Tensor) else x, critic_outputs)

    def _act_rand(self):
        raise NotImplementedError

    def update(self, rollout_storage):
        """Updates the Q-network with ``update_iterations`` gradient steps.

        Each step samples a batch from ``rollout_storage`` and regresses the Q-value of the executed
        action toward the one-step TD target ``r + (1 - done) * discount * max_a' Q(s', a')``.
        ``Q(s', .)`` comes from the target network if enabled, otherwise from the online critic
        (without gradients).

        Args:
            rollout_storage: replay storage providing ``sample(n_samples=...)`` and ``rollout_stats()``.

        Returns:
            Logging info (converted with ``ten2ar``) from the last update iteration; an empty
            ``AttrDict`` if ``update_iterations`` is 0.
        """
        info = AttrDict()
        for _ in range(self._hp.update_iterations):
            experience_batch = rollout_storage.sample(n_samples=self._hp.batch_size)
            critic_output = self.critic(experience_batch.observation)

            # get current q estimate for executed action: mask Q-values with a one-hot encoding of the
            # action. The mask is built on the Q-values' device/dtype (so it also works on GPU), and
            # actions are cast to int64 because index tensors must be integer-typed.
            q = critic_output.q
            actions = torch.as_tensor(experience_batch.action, dtype=torch.int64, device=q.device)
            one_hot_action = torch.eye(q.shape[-1], device=q.device, dtype=q.dtype)[actions]
            q_est = (q * one_hot_action).sum(dim=1)

            # compute target q value (no gradient flows through the target)
            with torch.no_grad():
                # fall back to the online critic when no target network was created in __init__
                next_q_network = self.critic_target if self._hp.use_target_network else self.critic
                critic_output_next = next_q_network(experience_batch.observation_next)
                q_next = critic_output_next.q.max(dim=-1)[0]
                q_target = experience_batch.reward + (1 - experience_batch.done) * self._hp.discount_factor * q_next

            # compute critic loss (mean squared TD error)
            critic_loss = (q_est - q_target).pow(2).mean()

            # update critic
            self._perform_update(critic_loss, self.critic_opt, self.critic)

            # update target network (presumably Polyak averaging in BaseAgent -- confirm there)
            if self._hp.use_target_network:
                self._soft_update_target_network(self.critic_target, self.critic)

            # update epsilon (decays once per update iteration)
            self._update_eps()

            # logging
            info = AttrDict(  # losses
                critic_loss=critic_loss,
            )
            info.update(AttrDict(  # gradient norms
                critic_grad_norm=avg_grad_norm(self.critic),
            ))
            info.update(AttrDict(  # misc
                q_target=q_target,
                q_est=q_est,
            ))
            info.update(rollout_storage.rollout_stats())  # stats of last rollouts, e.g. avg reward
            info = map_dict(ten2ar, info)

        # return only after all update iterations have run
        return info

    def _update_eps(self):
        """Reduce epsilon for epsilon-greedy exploration, never going below ``min_epsilon``."""
        self._eps = max(self._hp.min_epsilon, self._eps - self._hp.epsilon_decay)

    def _sample_rand_action(self, q):
        """Samples a uniformly random action index given Q-values ``q`` of shape (batch, n_actions).

        Returns an int64 tensor of shape (batch,) on ``self._hp.device``; only batch size 1 is supported.
        """
        batch, n_actions = q.shape
        assert batch == 1       # TODO implement eps greedy exploration for batched act()
        rand_action = np.random.choice(n_actions, batch)
        return torch.tensor(rand_action, device=self._hp.device, dtype=torch.int64)

    def dummy_output(self):
        """Returns the critic's dummy output extended with an empty ``action`` entry."""
        dummy_output = self.critic.dummy_output()
        dummy_output.update(AttrDict(action=None))
        return dummy_output

